{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/ucav/PythonProjects/fly-craft-examples\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import sys\n",
    "import argparse\n",
    "from copy import deepcopy\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from ray.util.multiprocessing import Pool\n",
    "\n",
    "from stable_baselines3.ppo import MlpPolicy\n",
    "from stable_baselines3.common.vec_env import VecCheckNan\n",
    "\n",
    "from flycraft.env import FlyCraftEnv\n",
    "from flycraft.utils.load_config import load_config\n",
    "\n",
    "# The project root is taken to be the parent of the current working directory,\n",
    "# so this cell expects the kernel to be started from a direct sub-folder of it.\n",
    "PROJECT_ROOT_DIR = Path.cwd().parent\n",
    "\n",
    "# Register the root on sys.path only once (re-running the cell adds nothing) so\n",
    "# that the project-local `utils_my` imports below can be resolved.\n",
    "project_root_entry = str(PROJECT_ROOT_DIR)\n",
    "if project_root_entry not in sys.path:\n",
    "    sys.path.append(project_root_entry)\n",
    "\n",
    "print(PROJECT_ROOT_DIR)\n",
    "\n",
    "from utils_my.models.ppo_with_bc_loss import PPOWithBCLoss\n",
    "from utils_my.sb3.vec_env_helper import get_vec_env, make_env\n",
    "from utils_my.smoothness.smoothness_measure import smoothness_measure_by_delta, smoothness_measure_by_fft\n",
    "from utils_my.sb3.my_wrappers import ScaledActionWrapper, ScaledObservationWrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "({'observation': array([0.5      , 0.5145179, 0.5      , 0.2      , 0.5      , 0.5      ,\n",
       "         0.5      , 0.25     ], dtype=float32),\n",
       "  'desired_goal': array([0.20118216, 0.5500515 , 0.44069326], dtype=float32),\n",
       "  'achieved_goal': array([0.2, 0.5, 0.5], dtype=float32)},\n",
       " {'plane_state': {'lef': 0.0,\n",
       "   'npos': 0.0,\n",
       "   'epos': 0.0,\n",
       "   'h': 5000.0,\n",
       "   'alpha': 2.613220523094039,\n",
       "   'beta': 0.0,\n",
       "   'phi': 0.0,\n",
       "   'theta': 2.613220523094039,\n",
       "   'psi': 0.0,\n",
       "   'p': 0.0,\n",
       "   'q': 0.0,\n",
       "   'r': 0.0,\n",
       "   'v': 200.0,\n",
       "   'vn': 200.00000000000003,\n",
       "   've': 0.0,\n",
       "   'vh': 1.7763568394002505e-15,\n",
       "   'nx': 0.04559349104912459,\n",
       "   'ny': 0.0,\n",
       "   'nz': 0.998960076119232,\n",
       "   'ele': -1.327651593874395,\n",
       "   'ail': 0.0,\n",
       "   'rud': 0.0,\n",
       "   'thrust': 0.0,\n",
       "   'lon': 122.425,\n",
       "   'lat': 31.425,\n",
       "   'mu': 5.088887490067862e-16,\n",
       "   'chi': 0.0}})"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Base environment, configured from the env config file shipped under configs/env,\n",
    "# with an empty custom_config.\n",
    "env = FlyCraftEnv(\n",
    "    config_file=PROJECT_ROOT_DIR.joinpath(\"configs\", \"env\", \"env_config_for_ppo_easy.json\"),\n",
    "    custom_config=dict(),\n",
    ")\n",
    "\n",
    "# Wrapper stack: ScaledObservationWrapper first, ScaledActionWrapper on top of it.\n",
    "scaled_obs_env = ScaledObservationWrapper(env)\n",
    "scaled_act_env = ScaledActionWrapper(scaled_obs_env)\n",
    "\n",
    "# Seeded reset through the outermost wrapper; the returned tuple is the cell output.\n",
    "scaled_act_env.reset(seed=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False 201.18216247002567 9.009273926518706 -21.350423236821975\n"
     ]
    }
   ],
   "source": [
    "# Inspect the goal sampler held by the base env: the fixed-goal flag followed by\n",
    "# the v / mu / chi goal values it currently stores.\n",
    "goal_sampler = env.task.goal_sampler\n",
    "print(\n",
    "    goal_sampler.use_fixed_goal,\n",
    "    goal_sampler.goal_v,\n",
    "    goal_sampler.goal_mu,\n",
    "    goal_sampler.goal_chi,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'v': 170.34552406761497, 'mu': -4.753733191163009, 'chi': 15.021880357803155}"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ask the goal sampler of the innermost env (reached via `.unwrapped`) for one goal\n",
    "# and show it as the cell output.\n",
    "sampled_goal = scaled_act_env.unwrapped.task.goal_sampler.sample_goal()\n",
    "sampled_goal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn on the sampler's fixed-goal flag and overwrite the v / mu / chi goal values\n",
    "# it stores. The sampler is reached through the wrapper stack via `.unwrapped`.\n",
    "goal_sampler = scaled_act_env.unwrapped.task.goal_sampler\n",
    "goal_sampler.use_fixed_goal = True\n",
    "goal_sampler.goal_v = 200\n",
    "goal_sampler.goal_mu = 10\n",
    "goal_sampler.goal_chi = 30"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'v': 200, 'mu': 10, 'chi': 30}"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sample once more, this time through the base `env` handle; the stored output\n",
    "# shows the goal values that were assigned to the sampler in the previous cell.\n",
    "fixed_goal = env.task.goal_sampler.sample_goal()\n",
    "fixed_goal"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "disc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
